{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f906ef93",
   "metadata": {},
   "source": [
    "# Training a Model on Narrowband for Classification\n",
    "\n",
    "This notebook demonstrates how to train a PyTorch model on the Narrowband dataset for modulation recognition.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "20666ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Variables\n",
    "from torchsig.signals.signal_lists import TorchSigSignalLists\n",
    "from torchsig.transforms.dataset_transforms import ComplexTo2D\n",
    "from torchsig.transforms.target_transforms import ClassIndex\n",
    "\n",
    "# Directory used as the dataset root (the test set is written to a `test` subfolder)\n",
    "root = \"./datasets/narrowband_classifier_example\"\n",
    "\n",
    "# Each example holds fft_size * fft_size IQ samples\n",
    "fft_size = 256\n",
    "num_iq_samples_dataset = fft_size * fft_size\n",
    "\n",
    "# Signal classes to recognize; the dataset sizes scale with the number of classes\n",
    "class_list = TorchSigSignalLists.all_signals\n",
    "num_classes = len(class_list)\n",
    "num_samples_train = 10 * num_classes  # roughly 10 samples per class\n",
    "num_samples_val = 2 * num_classes  # roughly 2 samples per class\n",
    "\n",
    "# Impairment level and random seed handed to the dataset metadata\n",
    "impairment_level = 0\n",
    "seed = 123456789\n",
    "\n",
    "# ComplexTo2D turns an IQ array of complex values into a 2D array, with one channel for the real component and the other for the imaginary component\n",
    "transforms = [ComplexTo2D()]\n",
    "# ClassIndex turns our target labels into the index of the class according to class_list\n",
    "target_transforms = [ClassIndex()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e1c3e43",
   "metadata": {},
   "source": [
    "## Create the Narrowband Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "442b71b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.dataset_metadata import NarrowbandMetadata\n",
    "from torchsig.datasets.datamodules import NarrowbandDataModule\n",
    "\n",
    "# Describe the signals to generate: samples per example, FFT size,\n",
    "# impairment level, class list and seed\n",
    "dataset_metadata = NarrowbandMetadata(\n",
    "    num_iq_samples_dataset=num_iq_samples_dataset,\n",
    "    fft_size=fft_size,\n",
    "    impairment_level=impairment_level,\n",
    "    class_list=class_list,\n",
    "    seed=seed,\n",
    ")\n",
    "\n",
    "# The data module receives the train/val sizes plus the data and target transforms.\n",
    "# NOTE(review): create_* presumably configure dataset generation, while\n",
    "# batch_size/num_workers configure the dataloaders -- confirm against NarrowbandDataModule\n",
    "narrowband_datamodule = NarrowbandDataModule(\n",
    "    root=root,\n",
    "    dataset_metadata=dataset_metadata,\n",
    "    num_samples_train=num_samples_train,\n",
    "    num_samples_val=num_samples_val,\n",
    "    transforms=transforms,\n",
    "    target_transforms=target_transforms,\n",
    "    create_batch_size=4,\n",
    "    create_num_workers=4,\n",
    "    batch_size=4,\n",
    "    num_workers=4,\n",
    ")\n",
    "# Run the prepare/setup steps by hand so the training split can be inspected before training\n",
    "narrowband_datamodule.prepare_data()\n",
    "narrowband_datamodule.setup()\n",
    "\n",
    "# Look at the first training example\n",
    "data, targets = narrowband_datamodule.train[0]\n",
    "print(f\"Data shape: {data.shape}\")\n",
    "print(f\"Targets: {targets}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c54d248",
   "metadata": {},
   "source": [
    "## Create the Model\n",
    "\n",
    "We use our own XCiT model code and utils, but this can be replaced with your own model architecture in PyTorch, Ultralytics, timm, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d9a89ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.models import XCiTClassifier\n",
    "from torchinfo import summary\n",
    "\n",
    "# Two input channels: ComplexTo2D provides one channel for the real part and one for the imaginary part\n",
    "model = XCiTClassifier(input_channels=2, num_classes=num_classes)\n",
    "\n",
    "# Display the layer/parameter summary of the model\n",
    "summary(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1826b1b8",
   "metadata": {},
   "source": [
    "## Train the Model\n",
    "\n",
    "Using the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html), we can train our model for modulation recognition on the Narrowband IQ dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fe3c33d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "# Number of passes over the training set\n",
    "num_epochs = 10\n",
    "\n",
    "# Train on a single device: a GPU when CUDA is available, otherwise the CPU\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=num_epochs,\n",
    "    accelerator=\"gpu\" if torch.cuda.is_available() else \"cpu\",\n",
    "    devices=1,\n",
    ")\n",
    "\n",
    "# The data module supplies the train and validation data\n",
    "trainer.fit(model, narrowband_datamodule)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ec35214",
   "metadata": {},
   "source": [
    "## Test the Model\n",
    "\n",
    "Now that we've trained the model, we can test its predictions on a new dataset (not used in training)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89805f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchsig.datasets.narrowband import NewNarrowband, StaticNarrowband\n",
    "from torchsig.utils.writer import DatasetCreator\n",
    "import torch\n",
    "# Release cached GPU memory left over from training\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "test_dataset_size = 10\n",
    "\n",
    "# Same signal configuration as the training data, but with a different seed\n",
    "# so the test examples are not the ones used for training\n",
    "dataset_metadata_test = NarrowbandMetadata(\n",
    "    num_iq_samples_dataset=num_iq_samples_dataset,\n",
    "    fft_size=fft_size,\n",
    "    impairment_level=impairment_level,\n",
    "    class_list=class_list,\n",
    "    num_samples=test_dataset_size,\n",
    "    transforms=transforms,\n",
    "    target_transforms=target_transforms,\n",
    "    seed=123456788,  # different than train\n",
    ")\n",
    "\n",
    "# Generate the test dataset and write it to {root}/test, replacing any previous copy\n",
    "dc = DatasetCreator(\n",
    "    dataset=NewNarrowband(dataset_metadata=dataset_metadata_test),\n",
    "    root=f\"{root}/test\",\n",
    "    overwrite=True,\n",
    "    batch_size=1,\n",
    "    num_workers=1,\n",
    ")\n",
    "dc.create()\n",
    "\n",
    "# Load the dataset that was just written to disk\n",
    "test_narrowband = StaticNarrowband(\n",
    "    root=f\"{root}/test\",\n",
    "    impaired=impairment_level > 0,\n",
    ")\n",
    "\n",
    "# Inspect the first test example; print its own label rather than the\n",
    "# `targets` value left over from the training cell\n",
    "data, class_index = test_narrowband[0]\n",
    "print(f\"Data shape: {data.shape}\")\n",
    "print(f\"Targets: {class_index}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93f7135e",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "data, class_index = test_narrowband[0]\n",
    "# move the model to the device the input will be placed on\n",
    "model.to(device)\n",
    "# put the model into evaluation mode\n",
    "model.eval()\n",
    "with torch.no_grad(): # gradients are not needed for inference\n",
    "    # convert to a tensor, move it to the device and add a batch dimension\n",
    "    data_tensor = torch.from_numpy(data).to(device).unsqueeze(dim=0)\n",
    "    # have the model predict the data\n",
    "    # returns one score per signal class\n",
    "    # NOTE(review): these may be raw logits rather than probabilities -- confirm against XCiTClassifier\n",
    "    pred = model(data_tensor)\n",
    "    # print(pred) # if you want to see the list of per-class scores\n",
    "\n",
    "    # choose the class with the highest score; .item() yields a plain Python int\n",
    "    # that can be used directly as an index into class_list\n",
    "    predicted_class = torch.argmax(pred).item()\n",
    "    print(f\"Predicted = {predicted_class} ({class_list[predicted_class]})\")\n",
    "    print(f\"Actual = {class_index} ({class_list[class_index]})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed77bdce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# We can do this over the whole test dataset to check the accuracy of our model\n",
    "num_correct = 0\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# moving the model and switching to evaluation mode only needs to happen once,\n",
    "# so it is done before the loop instead of on every sample\n",
    "model.to(device)\n",
    "model.eval()\n",
    "with torch.no_grad(): # gradients are not needed for inference\n",
    "    for data, actual_class in test_narrowband:\n",
    "        # convert to a tensor, move it to the device and add a batch dimension\n",
    "        data_tensor = torch.from_numpy(data).to(device).unsqueeze(dim=0)\n",
    "        pred = model(data_tensor)\n",
    "        # .item() yields a plain Python int for the comparison below\n",
    "        predicted_class = torch.argmax(pred).item()\n",
    "        if predicted_class == actual_class:\n",
    "            num_correct += 1\n",
    "\n",
    "# try increasing num_epochs or train dataset size to increase accuracy\n",
    "print(f\"Correct Predictions = {num_correct}\")\n",
    "# scale the fraction by 100 so the printed value really is a percentage\n",
    "print(f\"Percent Correct = {100 * num_correct / len(test_narrowband):.1f}%\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "py:percent,ipynb"
  },
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
